{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c0f57371-fbd0-4a3e-94fb-4c9c8aea956c",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91b49f99-5a9f-4bbe-a034-fb8a5f3fc71d",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd80c262-cf94-48df-b78e-c54a88a7ffb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite,pyyaml]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c36673a2-02cd-4eea-90ef-8226832c30d0",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeeee791-025e-4b1d-9dec-ebc83a8be4eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from monai.config import print_config\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fdad73c-f1ab-4874-9e4e-af687f78801a",
   "metadata": {},
   "source": [
    "# Integrating Non-MONAI Code Into a Bundle\n",
    "\n",
    "This notebook will discuss strategies for integrating non-MONAI deep learning code into a bundle. This allows existing Pytorch workflows to be integrated into the bundle ecosystem, for example as a distributable bundle for the model zoo or some other repository like Hugging Face, or to integrate with MONAI Label. The assumption taken here is that you already have the components for preprocessing, inference, validation, and other parts of a workflow, and so the task is how to integrate these parts into MONAI types which can be embedded in config files.\n",
    "\n",
    "In the following cells we'll construct a bundle which follows the [CIFAR10 tutorial](https://github.com/pytorch/tutorials/blob/32d834139b8627eeacb5fb2862be9f095fcb0b52/beginner_source/blitz/cifar10_tutorial.py) in Pytorch's tutorials repo. A number of code components will be copied into the `scripts` directory of the bundle and linked into config files suitable to be used on the command line.\n",
    "\n",
    "We'll start with an initialised bundle with a \"scripts\" directory and provide an appropriate metadata file describing the CIFAR10 classification network we'll provide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb9dc6ec-13da-4a37-8afa-28e2766b9343",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/usr/bin/tree\n",
      "\u001b[01;34mIntegrationBundle\u001b[00m\n",
      "├── \u001b[01;34mconfigs\u001b[00m\n",
      "│   └── metadata.json\n",
      "├── \u001b[01;34mdocs\u001b[00m\n",
      "│   └── README.md\n",
      "├── LICENSE\n",
      "├── \u001b[01;34mmodels\u001b[00m\n",
      "└── \u001b[01;34mscripts\u001b[00m\n",
      "    └── __init__.py\n",
      "\n",
      "4 directories, 4 files\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "python -m monai.bundle init_bundle IntegrationBundle\n",
    "rm IntegrationBundle/configs/inference.json\n",
    "mkdir IntegrationBundle/scripts\n",
    "echo \"\" > IntegrationBundle/scripts/__init__.py\n",
    "\n",
    "which tree && tree IntegrationBundle || true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b29f053b-cf16-4ffc-bbe7-d9433fdfa872",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting IntegrationBundle/configs/metadata.json\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/configs/metadata.json\n",
    "\n",
    "{\n",
    "    \"version\": \"0.0.1\",\n",
    "    \"changelog\": {\n",
    "        \"0.0.1\": \"Initial version\"\n",
    "    },\n",
    "    \"monai_version\": \"1.2.0\",\n",
    "    \"pytorch_version\": \"2.0.0\",\n",
    "    \"numpy_version\": \"1.23.5\",\n",
    "    \"optional_packages_version\": {\n",
    "        \"torchvision\": \"0.15.0\"\n",
    "    },\n",
    "    \"name\": \"IntegrationBundle\",\n",
    "    \"task\": \"Example Bundle\",\n",
    "    \"description\": \"This illustrates integrating non-MONAI code (CIFAR10 classification) into a bundle\",\n",
    "    \"authors\": \"Your Name Here\",\n",
    "    \"copyright\": \"Copyright (c) Your Name Here\",\n",
    "    \"data_source\": \"CIFAR10\",\n",
    "    \"data_type\": \"float32\",\n",
    "    \"intended_use\": \"This is suitable for demonstration only\",\n",
    "    \"network_data_format\": {\n",
    "        \"inputs\": {\n",
    "            \"image\": {\n",
    "                \"type\": \"image\",\n",
    "                \"format\": \"magnitude\",\n",
    "                \"modality\": \"natural\",\n",
    "                \"num_channels\": 3,\n",
    "                \"spatial_shape\": [32, 32],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [-1, 1],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\n",
    "                    \"0\": \"red\",\n",
    "                    \"1\": \"green\",\n",
    "                    \"2\": \"blue\"\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"outputs\": {\n",
    "            \"pred\": {\n",
    "                \"type\": \"probabilities\",\n",
    "                \"format\": \"classes\",\n",
    "                \"num_channels\": 10,\n",
    "                \"spatial_shape\": [10],\n",
    "                \"dtype\": \"float32\",\n",
    "                \"value_range\": [0, 1],\n",
    "                \"is_patch_data\": false,\n",
    "                \"channel_def\": {\n",
    "                    \"0\": \"plane\",\n",
    "                    \"1\": \"car\",\n",
    "                    \"2\": \"bird\",\n",
    "                    \"3\": \"cat\",\n",
    "                    \"4\": \"deer\",\n",
    "                    \"5\": \"dog\",\n",
    "                    \"6\": \"frog\",\n",
    "                    \"7\": \"horse\",\n",
    "                    \"8\": \"ship\",\n",
    "                    \"9\": \"truck\"\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9eac927-052d-4632-966f-a87f06311b9b",
   "metadata": {},
   "source": [
    "Note that `torchvision` was added as an optional package but will be required to run the bundle. \n",
    "\n",
    "## Scripts\n",
    "\n",
    "Taking the CIFAR10 tutorial as the \"codebase\" we're using currently, which we want to convert into a bundle, we want to copy components into `scripts` from that codebase. We'll start with the network given in the tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dcdbe1ae-ea13-49cb-b5a3-3c2c78f91f2b",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/scripts/net.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/net.py\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6d11fac-ad12-4f47-a0cb-5c78263e1142",
   "metadata": {},
   "source": [
    "Data transforms and data loaders are provided using definitions from `torchvision`. If we assume that these aren't easily converted into MONAI types, we instead need a function to return data loaders which will be used in config files. We could adapt the existing code by simply copying it into a function returning these definitions for use in the bundle:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "189d71c5-6556-4891-a382-0adbc8f80d30",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/scripts/transforms.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/transforms.py\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3d8f233e-495c-450c-a445-46d295ba7461",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/scripts/dataloaders.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/dataloaders.py\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "batch_size = 4\n",
    "\n",
    "\n",
    "def get_dataloader(is_training, transform):\n",
    "\n",
    "    if is_training:\n",
    "        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                                download=True, transform=transform)\n",
    "        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
    "                                                  shuffle=True, num_workers=2)\n",
    "        return trainloader\n",
    "    else:\n",
    "        testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                               download=True, transform=transform)\n",
    "        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                                 shuffle=False, num_workers=2)\n",
    "        return testloader   "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "317e2abf-673d-4a84-9afb-187bf01da278",
   "metadata": {},
   "source": [
    "The training process in the tutorial is just a loop going through the dataset twice. The simplest adaptation for this is to wrap it in a function taking only the network and dataloader as arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1a836b1b-06da-4866-82a2-47d1efed5d7c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/scripts/train.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/train.py\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "def train(net, trainloader):\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "    for epoch in range(2): \n",
    "\n",
    "        running_loss = 0.0\n",
    "        for i, data in enumerate(trainloader, 0):\n",
    "            inputs, labels = data\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            outputs = net(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "            if i % 2000 == 1999:  \n",
    "                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
    "                running_loss = 0.0\n",
    "\n",
    "    print('Finished Training')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3baf799c-8f3d-4a84-aa0d-6acbe1a0d96b",
   "metadata": {},
   "source": [
    "This function will hard code all sorts of parameters like loss function, learning rate, epoch count, etc. For this example it will work but of course if you're adapting other code it would make sense to include more parameterisation to your wrapper components. \n",
    "\n",
    "## Training\n",
    "\n",
    "We can now define a training config file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0b9764a8-674c-42ae-ad4b-f2dea027bdbf",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/configs/train.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/configs/train.yaml\n",
    "\n",
    "imports:\n",
    "- $import torch\n",
    "- $import scripts\n",
    "- $import scripts.net\n",
    "- $import scripts.train\n",
    "- $import scripts.transforms\n",
    "- $import scripts.dataloaders\n",
    "\n",
    "net:\n",
    "  _target_: scripts.net.Net\n",
    "\n",
    "transforms: '$scripts.transforms.transform'\n",
    "\n",
    "dataloader: '$scripts.dataloaders.get_dataloader(True, @transforms)'\n",
    "\n",
    "train:\n",
    "- $scripts.train.train(@net, @dataloader)\n",
    "- $torch.save(@net.state_dict(), './cifar_net.pth')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6c88aea-8182-44f1-853c-7d728bdae45b",
   "metadata": {},
   "source": [
    "The key concept demonstrated here is how to refer to definitions in the `scripts` directory within a config file and tie them together into a program. These definitions can be existing types or wrapper functions around existing code to make them easier to refer to here. A lot of good practice is ignored here but it shows how to adapt code into a bundle with minimal changes.\n",
    "\n",
    "Let's train something simple with this setup:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "65149911-3771-4a49-ade6-378305a4b946",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-09-11 17:28:16,125 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2023-09-11 17:28:16,125 - INFO - > run_id: 'train'\n",
      "2023-09-11 17:28:16,125 - INFO - > meta_file: './IntegrationBundle/configs/metadata.json'\n",
      "2023-09-11 17:28:16,125 - INFO - > config_file: './IntegrationBundle/configs/train.yaml'\n",
      "2023-09-11 17:28:16,125 - INFO - > bundle_root: './IntegrationBundle'\n",
      "2023-09-11 17:28:16,125 - INFO - ---\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Default logging file in 'configs/logging.conf' does not exist, skipping logging.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 170498071/170498071 [00:56<00:00, 3010200.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data/cifar-10-python.tar.gz to ./data\n",
      "[1,  2000] loss: 2.162\n",
      "[1,  4000] loss: 1.888\n",
      "[1,  6000] loss: 1.688\n",
      "[1,  8000] loss: 1.580\n",
      "[1, 10000] loss: 1.487\n",
      "[1, 12000] loss: 1.446\n",
      "[2,  2000] loss: 1.402\n",
      "[2,  4000] loss: 1.392\n",
      "[2,  6000] loss: 1.339\n",
      "[2,  8000] loss: 1.317\n",
      "[2, 10000] loss: 1.276\n",
      "[2, 12000] loss: 1.275\n",
      "Finished Training\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./IntegrationBundle\"\n",
    "\n",
    "export PYTHONPATH=$BUNDLE\n",
    "\n",
    "python -m monai.bundle run train \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"$BUNDLE/configs/train.yaml\" "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c27ba04-3271-4119-a57a-698aa7a83409",
   "metadata": {},
   "source": [
    "## Testing \n",
    "\n",
    "The second part of the tutorial script is testing the network with the test data which can again be put into a simple routine called from a config file: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fc35814e-625d-4871-ac1c-200a0cc562d9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/scripts/test.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/test.py\n",
    "\n",
    "import torch\n",
    "\n",
    "def test(net, testloader):\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            images, labels = data\n",
    "            outputs = net(images)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fb49aef2-9fb5-4e74-83d2-9da935e07648",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing IntegrationBundle/configs/test.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/configs/test.yaml\n",
    "\n",
    "imports:\n",
    "- $import torch\n",
    "- $import scripts\n",
    "- $import scripts.test\n",
    "- $import scripts.transforms\n",
    "- $import scripts.dataloaders\n",
    "\n",
    "net:\n",
    "  _target_: scripts.net.Net\n",
    "\n",
    "transforms: '$scripts.transforms.transform'\n",
    "\n",
    "dataloader: '$scripts.dataloaders.get_dataloader(False, @transforms)'\n",
    "\n",
    "test:\n",
    "- $@net.load_state_dict(torch.load('./cifar_net.pth'))\n",
    "- $scripts.test.test(@net, @dataloader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab171286-045c-4067-a2ea-be359168869d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-09-11 17:31:17,644 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2023-09-11 17:31:17,644 - INFO - > run_id: 'test'\n",
      "2023-09-11 17:31:17,644 - INFO - > meta_file: './IntegrationBundle/configs/metadata.json'\n",
      "2023-09-11 17:31:17,644 - INFO - > config_file: './IntegrationBundle/configs/test.yaml'\n",
      "2023-09-11 17:31:17,644 - INFO - > bundle_root: './IntegrationBundle'\n",
      "2023-09-11 17:31:17,644 - INFO - ---\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Default logging file in 'configs/logging.conf' does not exist, skipping logging.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Accuracy of the network on the 10000 test images: 54 %\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./IntegrationBundle\"\n",
    "\n",
    "export PYTHONPATH=$BUNDLE\n",
    "\n",
    "python -m monai.bundle run test \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"$BUNDLE/configs/test.yaml\" "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f218b72-734b-4b6e-93e5-990b8c647e8a",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "The original script lacked a section on inference with the network, however this is rather straight forward and so a script and config file can easily implement this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f510a23-aa3a-4e34-81e2-b4c719d87939",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting IntegrationBundle/scripts/inference.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/scripts/inference.py\n",
    "\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "def inference(net, transforms, filenames):\n",
    "    for fn in filenames:\n",
    "        with Image.open(fn) as im:\n",
    "            tim=transforms(im)\n",
    "            outputs=net(tim[None])\n",
    "            _, predictions=torch.max(outputs, 1)\n",
    "            print(fn, predictions[0].item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7f1251be-f0dd-4cbf-8903-3f3769c8049c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting IntegrationBundle/configs/inference.yaml\n"
     ]
    }
   ],
   "source": [
    "%%writefile IntegrationBundle/configs/inference.yaml\n",
    "\n",
    "imports:\n",
    "- $import glob\n",
    "- $import torch\n",
    "- $import scripts\n",
    "- $import scripts.inference\n",
    "- $import scripts.transforms\n",
    "\n",
    "ckpt_path: './cifar_net.pth'\n",
    "\n",
    "input_dir: 'test_cifar10'\n",
    "input_files: '$sorted(glob.glob(@input_dir+''/*.*''))'\n",
    "\n",
    "net:\n",
    "  _target_: scripts.net.Net\n",
    "\n",
    "transforms: '$scripts.transforms.transform'\n",
    "\n",
    "inference:\n",
    "- $@net.load_state_dict(torch.load('./cifar_net.pth'))\n",
    "- $scripts.inference.inference(@net, @transforms, @input_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e14c3ea9-5d0f-4c62-9cfe-c3c02c7fe6e1",
   "metadata": {},
   "source": [
    "Here we'll create a test set of image files to load and predict on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "cc2f063b-43f4-403e-b963-cf42b7e08637",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_cifar10/img00.png Label: 3\n",
      "test_cifar10/img01.png Label: 8\n",
      "test_cifar10/img02.png Label: 8\n",
      "test_cifar10/img03.png Label: 0\n",
      "test_cifar10/img04.png Label: 6\n",
      "test_cifar10/img05.png Label: 6\n",
      "test_cifar10/img06.png Label: 1\n",
      "test_cifar10/img07.png Label: 6\n",
      "test_cifar10/img08.png Label: 3\n",
      "test_cifar10/img09.png Label: 1\n",
      "test_cifar10/img10.png Label: 0\n",
      "test_cifar10/img11.png Label: 9\n",
      "test_cifar10/img12.png Label: 5\n",
      "test_cifar10/img13.png Label: 7\n",
      "test_cifar10/img14.png Label: 9\n",
      "test_cifar10/img15.png Label: 8\n",
      "test_cifar10/img16.png Label: 5\n",
      "test_cifar10/img17.png Label: 7\n",
      "test_cifar10/img18.png Label: 8\n",
      "test_cifar10/img19.png Label: 6\n"
     ]
    }
   ],
   "source": [
    "root_dir = \".\"  # assuming CIFAR10 was downloaded to the current directory\n",
    "num_images = 20\n",
    "dataset = torchvision.datasets.CIFAR10(root=f\"{root_dir}/data\", train=False)\n",
    "\n",
    "!mkdir -p test_cifar10\n",
    "\n",
    "for i in range(num_images):\n",
    "    pil, label = dataset[i]\n",
    "    filename = f\"test_cifar10/img{i:02}.png\"\n",
    "    print(filename, \"Label:\", label)\n",
    "    pil.save(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "28d1230e-1d3a-4929-a266-e5f763dfde7f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-09-11 17:54:11,793 - INFO - --- input summary of monai.bundle.scripts.run ---\n",
      "2023-09-11 17:54:11,793 - INFO - > run_id: 'inference'\n",
      "2023-09-11 17:54:11,793 - INFO - > meta_file: './IntegrationBundle/configs/metadata.json'\n",
      "2023-09-11 17:54:11,793 - INFO - > config_file: './IntegrationBundle/configs/inference.yaml'\n",
      "2023-09-11 17:54:11,793 - INFO - > bundle_root: './IntegrationBundle'\n",
      "2023-09-11 17:54:11,793 - INFO - ---\n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Default logging file in 'configs/logging.conf' does not exist, skipping logging.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_cifar10/img00.png 3\n",
      "test_cifar10/img01.png 8\n",
      "test_cifar10/img02.png 8\n",
      "test_cifar10/img03.png 0\n",
      "test_cifar10/img04.png 6\n",
      "test_cifar10/img05.png 6\n",
      "test_cifar10/img06.png 1\n",
      "test_cifar10/img07.png 4\n",
      "test_cifar10/img08.png 3\n",
      "test_cifar10/img09.png 1\n",
      "test_cifar10/img10.png 0\n",
      "test_cifar10/img11.png 9\n",
      "test_cifar10/img12.png 6\n",
      "test_cifar10/img13.png 7\n",
      "test_cifar10/img14.png 9\n",
      "test_cifar10/img15.png 1\n",
      "test_cifar10/img16.png 5\n",
      "test_cifar10/img17.png 3\n",
      "test_cifar10/img18.png 8\n",
      "test_cifar10/img19.png 4\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "BUNDLE=\"./IntegrationBundle\"\n",
    "\n",
    "export PYTHONPATH=$BUNDLE\n",
    "\n",
    "python -m monai.bundle run inference \\\n",
    "    --bundle_root \"$BUNDLE\" \\\n",
    "    --meta_file \"$BUNDLE/configs/metadata.json\" \\\n",
    "    --config_file \"$BUNDLE/configs/inference.yaml\" "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1a06d82-1a8a-4607-8620-474e89061027",
   "metadata": {},
   "source": [
    "## Adaptation Strategies\n",
    "\n",
    "This notebook has demonstrated one strategy of integrating existing code into a bundle. Code from an existing project, in this case an example script, was copied into the `scripts` directory of a bundle with added functions to make definitions easily referenced in config files. What shows up in the config files is a thin adapter layer to interface between what is expected in bundles and the codebase. \n",
    "\n",
    "It's clear that a mixed approach, where old components are replaced with MONAI types, would also work well given the simplicity of the code here. Substituting the Torchvision transforms with those from MONAI, using a `Trainer` class instead of the `train` function, and similarly testing and inference using an `Evaluator` class, would produce essentially the same results. It is up to you to determine what rewriting of code in the config scripts is justified for your codebase rather than adapting things in some way. \n",
    "\n",
    "The third approach involves a codebase which is installed as a package. If an external network with its training components is installed with `pip` for example, perhaps no code would be needed to adapt into a bundle, and you can just write config scripts to import this package and reference its definitions. Some adapter code may be needed in `scripts` but this could be like those demonstrated here, simple wrapper functions returning objects assigned to keys in config files through evaluated Python expressions. \n",
    "\n",
    "Creating a bundle compatible with other tools requires you to define specific items in the config files. For example, MONAI Label states requirements [here](https://github.com/Project-MONAI/MONAILabel/blob/c90f42c0730554e3a05af93645ae84ccdcb5e14b/monailabel/tasks/infer/bundle.py#L33) as names that must be present in `inference.json/yaml` to work with the label server. You would have to provide `network_def`, `preprocessing`, `postprocessing`, and others. This means that the code from your existing codebase would have to be divided up into these components if it isn't already, and its inputs and output would have to match what would be expected of the MONAI types typically used for these definitions. \n",
    "\n",
    "If you need to adapt your code to a bundle it's going to be very specific to your situation how integration is going to work. Using config files as adapter layers is shown here to work, but by understanding how bundles are structured and what the moving pieces are to a bundle \"program\" you can figure out your own strategy.\n",
    "\n",
    "### Adapting Data Processing\n",
    "\n",
    "One common module is data processing, either pre or post at various stages. MONAI transforms assume that Numpy arrays or Pytorch tensors, or dictionaries thereof, are the inputs and outputs to transforms. You can integrate existing transforms using `Lambda/Lambdad` to wrap a callable object within a MONAI transform rather than define your own `Transform` subclass. This does require that data have the correct type and shape. For example, if you have a function in `scripts` simply called `preprocess` which accepts a single image input as a Numpy array, this can be adapted into a transform sequence as such:\n",
    "\n",
    "```python\n",
    "train_transforms:\n",
    "- _target_: LoadImage\n",
    "  image_only: true\n",
    "- _target_: EnsureChannelFirst\n",
    "- _target_: ToNumpy\n",
    "- _target_: Lambda\n",
    "  func: '$@scripts.preprocess'\n",
    "- _target_: ToTensor\n",
    "```\n",
    "\n",
    "Minimising conversions to and from different formats would improve performance but otherwise this avoids complex rewriting of code to fit MONAI tranforms. A preprocess function which takes multiple inputs and produces multiple outputs would be more suited to a dictionary-based transform sequence but would also require adaptor code or a `MapTransform` subclass. \n",
    "\n",
    "\n",
    "## Summary and Next\n",
    "\n",
    "In this tutorial we have looked at how to adapt code to a MONAI bundle:\n",
    "* Wrapping code in thin adaptation layers\n",
    "* Using these components in config files\n",
    "* Discussion of the architectural concepts around the process of adaptation\n",
    "\n",
    "In future tutorials we shall delve into other details and strategies with MONAI bundles."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:monai1]",
   "language": "python",
   "name": "conda-env-monai1-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
